{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n",
      "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.\n",
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-08-20 09:39:59,650] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:14<00:00,  7.46s/it]\n"
     ]
    }
   ],
   "source": [
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "model_path = \"../luwen_baichuan/output/zju_model_0818_110k\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Human:在上班路上脚踝关节骨折算不算是工伤，企业应该负什么责任？  Assistant: 脚踝关节骨折属于工伤范畴，企业应该负起相应的责任。根据《中华人民共和国工伤保险条例》，工伤是指在工作时间和工作场所内，因工作原因受到的伤害或者患病。因此，如果在上班路上发生脚踝关节骨折，可以认定为工伤。企业应该及时为员工提供医疗救治，并按照相关规定支付工伤保险待遇。同时，企业也应该加强安全管理，确保员工在工作期间的安全。如果企业没有为员工购买工伤保险，员工可以向劳动保障部门申请工伤认定和工伤保险待遇。\n"
     ]
    }
   ],
   "source": [
    "prompt = \"在上班路上脚踝关节骨折算不算是工伤，企业应该负什么责任？\"\n",
    "inputs = tokenizer(f'</s>Human:{prompt} </s>Assistant: ', return_tensors='pt')\n",
    "inputs = inputs.to('cuda')\n",
    "pred = model.generate(**inputs, max_new_tokens=300, repetition_penalty=1)\n",
    "print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
